#!/bin/bash

# Required configuration, supplied through the environment. The shell aborts
# at the first variable below that is unset or empty and prints that
# variable's message:
#   CONT   - base container image
#   DATA   - data directory
#   CKPT   - checkpoint directory
#   NODES  - number of nodes
#   OUTPUT - output directory
# The expansions are evaluated left to right, so the order of checks is the
# order in which they are listed.
: \
  "${CONT:?Base Container image is not set, please specify CONT envvar}" \
  "${DATA:?Data directory is not set, please specify DATA envvar}" \
  "${CKPT:?Checkpoint directory is not set, please specify CKPT envvar}" \
  "${NODES:?Number of nodes is not set, please specify NODES envvar}" \
  "${OUTPUT:?Output directory is not set, please specify OUTPUT envvar}"

# Comma-separated list of container mounts, one "source:target[:flags]" entry
# per line below. The dataset and checkpoint directories are mounted
# read-only (":ro"); the output directory is mounted at /results without the
# read-only flag.
CONT_MOUNTS="${DATA}:/app/dataset:ro"
CONT_MOUNTS+=",${CKPT}:/app/checkpoints:ro"
CONT_MOUNTS+=",${OUTPUT}:/results"

# Rendezvous endpoint exported for the launched tasks.
# MASTER_PORT keeps a caller-supplied value and falls back to 29500 when it is
# unset or empty.
export MASTER_PORT="${MASTER_PORT:-29500}"

# MASTER_ADDR is the first line printed by `scontrol show hostnames` for the
# job's node list. "${SLURM_JOB_NODELIST-}" expands to an empty string when the
# variable is unset, in which case MASTER_ADDR may end up empty.
# NOTE(review): an empty MASTER_ADDR is not treated as an error here — confirm
# that is intended.
MASTER_ADDR="$(scontrol show hostnames "${SLURM_JOB_NODELIST-}" | head -n1)"
export MASTER_ADDR

# Launch the training job: one task per GPU on every node, inside the
# container image given by CONT with the mounts listed in CONT_MOUNTS.
#
# Environment read here:
#   NODES          - number of nodes (required, positive integer)
#   GPUS           - tasks per node (optional, positive integer, default 8)
#   SLURM_MPI_TYPE - value for srun --mpi (optional, default "pmix")
#   CONT           - container image
#   CONT_MOUNTS    - container mount list
#
# The script's exit status is srun's exit status, since srun is the last
# command executed.

# Single source of truth for the per-node task count, used for both
# --ntasks and --ntasks-per-node.
gpus_per_node="${GPUS:-8}"

# Bash arithmetic evaluates non-numeric text as a variable name or expression
# (an unknown name silently becomes 0), so reject anything that is not a plain
# positive decimal integer before multiplying.
positive_int_re='^[1-9][0-9]*$'
if ! [[ "${NODES}" =~ ${positive_int_re} ]]; then
  printf 'NODES must be a positive integer, got: %s\n' "${NODES}" >&2
  exit 1
fi
if ! [[ "${gpus_per_node}" =~ ${positive_int_re} ]]; then
  printf 'GPUS must be a positive integer, got: %s\n' "${gpus_per_node}" >&2
  exit 1
fi

# Options are kept in arrays so that no line continuations are needed and a
# line added after the command can never be absorbed into srun's arguments.
srun_opts=(
  -l                                        # label output lines with the task number
  --kill-on-bad-exit=0                      # NOTE(review): 0 leaves other tasks running when one fails — confirm intended
  --mpi="${SLURM_MPI_TYPE:-pmix}"
  --ntasks="$(( NODES * gpus_per_node ))"
  --ntasks-per-node="${gpus_per_node}"
  --container-image="${CONT}"               # --container-* flags: presumably from the Pyxis srun plugin — confirm
  --container-mounts="${CONT_MOUNTS}"
  --container-env=MASTER_PORT,MASTER_ADDR   # forward the rendezvous endpoint into the container
)

# slurm2pytorch is an external wrapper; presumably it translates Slurm task
# variables into the ones PyTorch expects — verify against its source.
train_cmd=(
  slurm2pytorch python /app/training/run_clm.py
  output_dir=/results
  dataset.train_dataset_path=/app/dataset
  dataset.eval_dataset_path=/app/dataset
)

srun "${srun_opts[@]}" "${train_cmd[@]}"
